package com.jstarcraft.ai.jsat.math.optimization;

import java.util.ArrayList;
import java.util.List;

import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.math.Function;
import com.jstarcraft.ai.jsat.math.FunctionVec;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;

/**
 * Implementation of the limited memory variant of {@link BFGS}. It uses a
 * history of {@link #setM(int) m} items to solve {@code n} dimension problems
 * with {@code O(m n)} work per iteration.
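 * <p>
 * A minimal usage sketch (here {@code x0} is the starting point and {@code f} /
 * {@code fp} stand for the objective and its gradient; passing {@code null} for
 * the gradient falls back to a forward-difference approximation):
 * 
 * <pre>{@code
 * LBFGS lbfgs = new LBFGS(10, 500, new BacktrackingArmijoLineSearch());
 * Vec w = x0.clone(); // the minimizer is written into w
 * lbfgs.optimize(1e-4, w, x0, f, fp, false);
 * }</pre>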
 * 
 * @author Edward Raff
 */
public class LBFGS implements Optimizer {
    private int m;
    private int maxIterations;
    private LineSearch lineSearch;
    private boolean inftNormCriterion = true;

    /**
     * Creates a new L-BFGS optimization object that uses a maximum of 500
     * iterations and a {@link BacktrackingArmijoLineSearch Backtracking} line
     * search. A {@link #setM(int) history} of 10 items will be used.
     */
    public LBFGS() {
        this(10);
    }

    /**
     * Creates a new L-BFGS optimization object that uses a maximum of 500
     * iterations and a {@link BacktrackingArmijoLineSearch Backtracking} line
     * search.
     * 
     * @param m the number of history items
     */
    public LBFGS(int m) {
        this(m, 500, new BacktrackingArmijoLineSearch());
    }

    /**
     * Creates a new L-BFGS optimization object
     * 
     * @param m             the number of history items
     * @param maxIterations the maximum number of iterations before stopping
     * @param lineSearch    the line search method to use for optimization
     */
    public LBFGS(int m, int maxIterations, LineSearch lineSearch) {
        setM(m);
        setMaximumIterations(maxIterations);
        setLineSearch(lineSearch);
    }

    /**
     * Computes the product H<sub>k</sub> &nabla; f<sub>k</sub> using the L-BFGS
     * two-loop recursion (Algorithm 7.4). The history lists are expected to hold
     * the most recent items first (at index 0).
     * 
     * @param x_grad the initial value &nabla; f<sub>k</sub>
     * @param rho    the history of scalars &rho;<sub>i</sub> = 1/(y<sub>i</sub><sup>T</sup> s<sub>i</sub>)
     * @param s      the history of steps s<sub>i</sub> = x<sub>i+1</sub> &minus; x<sub>i</sub>
     * @param y      the history of gradient differences y<sub>i</sub> = &nabla; f<sub>i+1</sub> &minus; &nabla; f<sub>i</sub>
     * @param q      the location to store the value of H<sub>k</sub> &nabla;
     *               f<sub>k</sub>
     * @param alphas temp space to do work, should be as large as the number of
     *               history vectors
     */
    public static void twoLoopHp(Vec x_grad, List<Double> rho, List<Vec> s, List<Vec> y, Vec q, double[] alphas) {
        // q ← ∇ fk;
        x_grad.copyTo(q);
        if (s.isEmpty())
            return;// identity, we are done
        // for i = k−1,k−2,...,k−m
        for (int i = 0; i < s.size(); i++) {
            Vec s_i = s.get(i);
            Vec y_i = y.get(i);
            double alpha_i = alphas[i] = rho.get(i) * s_i.dot(q);
            q.mutableSubtract(alpha_i, y_i);
        }

        // r ← H_k^0 q; using the scaling γ_k = s^T y / (y^T y) from eq (7.20), done in place in q
        q.mutableMultiply(s.get(0).dot(y.get(0)) / y.get(0).dot(y.get(0)));
        // for i = k−m,k−m+1,...,k−1
        for (int i = s.size() - 1; i >= 0; i--) {
            // β ← ρ_i y_i^T r ;
            double beta = rho.get(i) * y.get(i).dot(q);
            // r ← r + si (αi − β)
            q.mutableAdd(alphas[i] - beta, s.get(i));
        }
    }

    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, boolean parallel) {
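        // if no analytic gradient is supplied, fall back to a numeric forward-difference approximation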
        if (fp == null)
            fp = Function.forwardDifference(f);
        LineSearch search = lineSearch.clone();
        final double[] f_xVal = new double[1];// storage space for f(x)

        // history for implicit H
        DoubleArrayList Rho = new DoubleArrayList(m);
        List<Vec> S = new ArrayList<>(m);
        List<Vec> Y = new ArrayList<>(m);

        Vec x_prev = x0.clone();
        Vec x_cur = x0.clone();
        f_xVal[0] = f.f(x_prev, parallel);
        // gradient
        Vec x_grad = x0.clone();
        x_grad.zeroOut();
        Vec x_gradPrev = x_grad.clone();
        // search direction p_k, and work space for s_k and y_k
        Vec p_k = x_grad.clone();
        Vec s_k = x_grad.clone();
        Vec y_k = x_grad.clone();

        x_grad = fp.f(x_cur, x_grad, parallel);

        double[] alphas = new double[m];
        int iter = 0;

        while (gradConvgHelper(x_grad) > tolerance && iter < maxIterations) {
            // p_k = −H_k ∇f_k; (6.18)
            twoLoopHp(x_grad, Rho, S, Y, p_k, alphas);
            p_k.mutableMultiply(-1);

            // Set x_k+1 = x_k + α_k p_k where α_k is computed from a line search
            x_cur.copyTo(x_prev);
            x_grad.copyTo(x_gradPrev);

            double alpha_k = search.lineSearch(1.0, x_prev, x_gradPrev, p_k, f, fp, f_xVal[0], x_gradPrev.dot(p_k), x_cur, f_xVal, x_grad, parallel);
            if (alpha_k < 1e-12)// if the step size is essentially zero, treat the optimization as converged
                break;

            if (!search.updatesGrad())
                fp.f(x_cur, x_grad, parallel);

            // Define s_k =x_k+1 −x_k and y_k = ∇f_k+1 −∇f_k;
            x_cur.copyTo(s_k);
            s_k.mutableSubtract(x_prev);
            S.add(0, s_k.clone());

            x_grad.copyTo(y_k);
            y_k.mutableSubtract(x_gradPrev);
            Y.add(0, y_k.clone());

            Rho.add(0, 1 / s_k.dot(y_k));
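            // if s_k^T y_k is zero the division above yields an infinite or NaN rho;
            // that curvature information is unusable, so drop the history and restart
            // the implicit Hessian from the identity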
            if (Double.isInfinite(Rho.get(0)) || Double.isNaN(Rho.get(0))) {
                Rho.clear();
                S.clear();
                Y.clear();
            }
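            // keep at most m history triples, discarding the oldest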
            while (Rho.size() > m) {
                Rho.remove(m);
                S.remove(m);
                Y.remove(m);
            }

            iter++;
        }

        x_cur.copyTo(w);
    }

    /**
     * By default the infinity norm is used to judge convergence. If set to
     * {@code false}, the 2 norm will be used instead.
     * 
     * @param inftNormCriterion {@code true} to use the infinity norm for
     *                          convergence, {@code false} to use the 2 norm
     */
    public void setInftNormCriterion(boolean inftNormCriterion) {
        this.inftNormCriterion = inftNormCriterion;
    }

    /**
     * Returns whether or not the infinity norm ({@code true}) or 2 norm
     * ({@code false}) is used to determine convergence.
     * 
     * @return {@code true} if the infinity norm is in use, {@code false} for the 2
     *         norm
     */
    public boolean isInftNormCriterion() {
        return inftNormCriterion;
    }

    private double gradConvgHelper(Vec grad) {
        if (!inftNormCriterion)
            return grad.pNorm(2);
        double max = 0;
        for (IndexValue iv : grad)
            max = Math.max(max, Math.abs(iv.getValue()));
        return max;
    }

    /**
     * Sets the number of history items to keep that are used to approximate the
     * inverse Hessian of the problem.
     * 
     * @param m the number of history items to keep
     */
    public void setM(int m) {
        if (m < 1)
            throw new IllegalArgumentException("m must be positive, not " + m);
        this.m = m;
    }

    /**
     * Returns the number of history items that will be used
     * 
     * @return the number of history items that will be used
     */
    public int getM() {
        return m;
    }

    /**
     * Sets the line search method used at each iteration
     * 
     * @param lineSearch the line search method used at each iteration
     */
    public void setLineSearch(LineSearch lineSearch) {
        this.lineSearch = lineSearch;
    }

    /**
     * Returns the line search method used at each iteration
     * 
     * @return the line search method used at each iteration
     */
    public LineSearch getLineSearch() {
        return lineSearch;
    }

    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1)
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        this.maxIterations = iterations;
    }

    @Override
    public int getMaximumIterations() {
        return maxIterations;
    }

    @Override
    public LBFGS clone() {
        return new LBFGS(m, maxIterations, lineSearch.clone());
    }

}
